function [dec, lambda_grid, pweights, U_f, U_m, U_f0, U_m0, U_ftot, U_mtot] = solution_backwards_ud(n_g, T, lt, dt, H_max, disc, det_earn_f, det_earn_m, stoch_earn_f, stoch_earn_m, m_eta_f, m_eta_m, n_eta_f, n_eta_m, psi, home_prod, th1, match_qual, m_th, n_th, gamma, L_max, lambda, sigma_f, sigma_m)

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The impact of divorce laws on the equilibrium in the marriage market.
% Reynoso
% April 2024
%
% Solve the model by backwards induction under unilateral divorce. A
% marriage state keeps its marriage values only if some Pareto weight on
% the grid gives BOTH spouses at least their divorce (autarky) value;
% otherwise the state keeps the divorce values.
%
% Data: PSID 1968-1992
%
% Main inputs (as used below):
%   n_g         number of groups g; must be a perfect square k^2, with
%               g = (i-1)*k + j (see "Group aggregates" at the end)
%   T, lt       last and first model periods; the recursion runs T..lt
%   H_max       size of the H state. Marriage option 2 moves H to H+1
%               next period. The recursion reads H up to T, so
%               H_max >= T is assumed.
%   disc        discount factor applied to expected continuation values
%   L_max       number of points in the Pareto-weight grid
%   lambda(g)   initial Pareto weight on the female value for group g
%   sigma_f/m   constants added to the single values U_f0 / U_m0
%
% Outputs:
%   dec         n_g x L_max x H_max x n_eta_f x n_eta_m x n_th x T.
%               1 or 2 = chosen marriage option; 3 = no grid weight
%               satisfies both participation constraints (divorce values).
%   lambda_grid 1 x L_max grid of Pareto weights on the female value
%   pweights    index into lambda_grid of the weight in effect after the
%               participation constraints (same shape as dec)
%   U_f, U_m    mean married value at t = lt, H = 1 over the
%               (eta_f, eta_m, theta) states
%   U_f0, U_m0  mean single value at t = lt plus sigma_f(g) / sigma_m(g)
%   U_ftot      per group: sum over the pooled groups of exp(U_f), plus
%               exp(U_f0). U_mtot is the analogue for males.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%--------------------------- Preliminaries ------------------------------%

% Validate the group layout up front. It is needed by the aggregation at
% the end, so a bad n_g should fail before the expensive solve.
k_types = round(sqrt(n_g));
if k_types*k_types ~= n_g
    error('solution_backwards_ud:nGroups', ...
        'n_g must be a perfect square (k x k groups); got n_g = %d.', n_g);
end

% Pareto-weight grid on the female value, with the endpoints moved inside (0,1).
% NOTE(review): the endpoints are nudged asymmetrically (1e-6 vs 1-1e-5);
% confirm this is intended.
lambda_grid=linspace(0,1,L_max);
lambda_grid(1)=0.000001;
lambda_grid(L_max)=0.99999;

dec=zeros(n_g,L_max,H_max,n_eta_f,n_eta_m,n_th,T);
pweights=zeros(n_g,L_max,H_max,n_eta_f,n_eta_m,n_th,T);

U_f=zeros(n_g,1);
U_m=zeros(n_g,1);
U_f0=zeros(n_g,1);
U_m0=zeros(n_g,1);
U_ftot=zeros(n_g,1);
U_mtot=zeros(n_g,1);

%--------------------------- Value functions -----------------------------%

for g=1:n_g

    %--- Transition matrices

    [lw_grid_m, P_eta_m, ~, ~] = wageproc_fn(stoch_earn_m(g,:), m_eta_m, n_eta_m);
    [lw_grid_f, P_eta_f, ~, ~] = wageproc_fn(stoch_earn_f(g,:), m_eta_f, n_eta_f);

    % Joint earnings transition (product of the two marginal transitions):
    % Pwage(i,j,i',j') = P_eta_f(i,i') * P_eta_m(j,j')
    Pwage=zeros(n_eta_f,n_eta_m,n_eta_f,n_eta_m);
    for i=1:n_eta_f
        for j=1:n_eta_m
            Pwage(i,j,:,:)=P_eta_f(i,:)'*P_eta_m(j,:);
        end
    end

    [thtau_grid, P_thtau, ~, ~] = thetaproc(match_qual(g,:), m_th, n_th);

    % Joint earnings x theta transition:
    % P_wg_th(i,j,k,i',j',k') = Pwage(i,j,i',j') * P_thtau(k,k')
    P_wg_th=zeros(n_eta_f,n_eta_m,n_th,n_eta_f,n_eta_m,n_th);
    for i=1:n_eta_f
        for j=1:n_eta_m
            for k=1:n_th
                for kp=1:n_th
                    P_wg_th(i,j,k,:,:,kp)=Pwage(i,j,:,:)*P_thtau(k,kp);
                end
            end
        end
    end

    %--- Divorce
    % val_Dx(H,wf,wm,t) = autarky flow + disc * E[val_Dx(H,wf',wm',t+1)].
    % H is held fixed after divorce, and theta does not enter.

    A=zeros(n_eta_f,n_eta_m);
    B_Af=zeros(n_eta_f,n_eta_m);
    B_Am=zeros(n_eta_f,n_eta_m);
    atkf=zeros(H_max,n_eta_f,n_eta_m,T);
    atkm=zeros(H_max,n_eta_f,n_eta_m,T);
    E_Af=zeros(H_max,n_eta_f,n_eta_m,T);
    E_Am=zeros(H_max,n_eta_f,n_eta_m,T);
    val_Df=zeros(H_max,n_eta_f,n_eta_m,T);
    val_Dm=zeros(H_max,n_eta_f,n_eta_m,T);

    % Terminal period: flow value only, for every H.
    t=T;
    for H=1:H_max
        for wf=1:n_eta_f
            for wm=1:n_eta_m
                [val_Df(H,wf,wm,t),val_Dm(H,wf,wm,t),~]=autarky(H,lw_grid_f(wf),lw_grid_m(wm),t,det_earn_f(g,:), det_earn_m(g,:), gamma,dt);
            end
        end
    end

    % Earlier periods: only H = 1:t is computed in period t.
    for t=T-1:-1:lt
        for H=1:t
            for wf=1:n_eta_f
                for wm=1:n_eta_m
                    [atkf(H,wf,wm,t),atkm(H,wf,wm,t),~]=autarky(H,lw_grid_f(wf),lw_grid_m(wm),t,det_earn_f(g,:), det_earn_m(g,:), gamma, dt);
                    A(:,:)=Pwage(wf,wm,:,:);
                    B_Af(:,:)=val_Df(H,:,:,t+1);
                    B_Am(:,:)=val_Dm(H,:,:,t+1);
                    E_Af(H,wf,wm,t)=sum(sum(A.*B_Af));
                    E_Am(H,wf,wm,t)=sum(sum(A.*B_Am));
                    val_Df(H,wf,wm,t)=atkf(H,wf,wm,t)+disc*E_Af(H,wf,wm,t);
                    val_Dm(H,wf,wm,t)=atkm(H,wf,wm,t)+disc*E_Am(H,wf,wm,t);
                end
            end
        end
    end

    %--- Marriage
    % Arrays are indexed (p, H, wf, wm, th, t). The trailing index of
    % val_matf / val_matm is the marriage option:
    %   option 1 continues with H next period;
    %   option 2 continues with H+1.
    stayf0=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    staym0=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    stayf1=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    staym1=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    E_Mf0=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    E_Mf1=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    E_Mm0=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    E_Mm1=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    val_matf=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T,2);
    val_matm=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T,2);
    dec_M=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    % 3 marks states where no grid weight satisfies both participation
    % constraints. Such states keep the divorce values set below.
    dec_ud=ones(L_max, H_max,n_eta_f,n_eta_m,n_th,T)*3;
    valM_f=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    valM_m=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    val_f=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    val_m=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);

    % Default every marriage state to the divorce value.
    for t=1:T
        for p=1:L_max
            for th=1:n_th
                val_f(p,:,:,:,th,t)=val_Df(:,:,:,t);
                val_m(p,:,:,:,th,t)=val_Dm(:,:,:,t);
            end
        end
    end

    % Grid point closest to the group's initial Pareto weight. It is the
    % default weight index and the continuation row used at t = lt.
    [~, p0]=min(abs(lambda_grid-lambda(g)));
    pareto_weight_index=zeros(L_max, H_max,n_eta_f,n_eta_m,n_th,T);
    pareto_weight_index(:)=p0;

    Awt=zeros(n_eta_f,n_eta_m,n_th);
    B_Mf0=zeros(n_eta_f,n_eta_m,n_th);
    B_Mm0=zeros(n_eta_f,n_eta_m,n_th);
    B_Mf1=zeros(n_eta_f,n_eta_m,n_th);
    B_Mm1=zeros(n_eta_f,n_eta_m,n_th);

    % Terminal period: flow values only, for every H.
    t=T;

    for p=1:L_max
        for H=1:H_max
            for wf=1:n_eta_f
                for wm=1:n_eta_m
                    for th=1:n_th
                        [stayf0(p,H,wf,wm,th,t),staym0(p,H,wf,wm,th,t),stayf1(p,H,wf,wm,th,t),staym1(p,H,wf,wm,th,t)]=stayM(H,lw_grid_f(wf),lw_grid_m(wm),t,det_earn_f(g,:), det_earn_m(g,:),home_prod(g,:),psi(g),th1(g),thtau_grid(th),lambda_grid(p),dt);
                    end
                end
            end
        end
    end

    val_matf(:,:,:,:,:,t,1)=stayf0(:,:,:,:,:,t);
    val_matf(:,:,:,:,:,t,2)=stayf1(:,:,:,:,:,t);
    val_matm(:,:,:,:,:,t,1)=staym0(:,:,:,:,:,t);
    val_matm(:,:,:,:,:,t,2)=staym1(:,:,:,:,:,t);

    % Household picks the option maximizing the lambda-weighted sum.
    for p=1:L_max
        [~,dec_M(p,:,:,:,:,t)]=max(lambda_grid(p)*val_matf(p,:,:,:,:,t,:)+(1-lambda_grid(p))*val_matm(p,:,:,:,:,t,:),[],7);
    end

    [valM_f(:,:,:,:,:,t), valM_m(:,:,:,:,:,t)] = select_option( ...
        val_matf(:,:,:,:,:,t,1), val_matf(:,:,:,:,:,t,2), ...
        val_matm(:,:,:,:,:,t,1), val_matm(:,:,:,:,:,t,2), dec_M(:,:,:,:,:,t));

    [val_f(:,:,:,:,:,t), val_m(:,:,:,:,:,t), dec_ud(:,:,:,:,:,t), pareto_weight_index(:,:,:,:,:,t)] = ...
        impose_participation(valM_f(:,:,:,:,:,t), valM_m(:,:,:,:,:,t), ...
        val_Df(:,:,:,t), val_Dm(:,:,:,t), dec_M(:,:,:,:,:,t), ...
        val_f(:,:,:,:,:,t), val_m(:,:,:,:,:,t), dec_ud(:,:,:,:,:,t), ...
        pareto_weight_index(:,:,:,:,:,t), H_max);

    % Interior periods t = T-1 .. lt+1 (only H = 1:t). The while form is kept
    % so that t after the loop (used for the initial period below) is lt
    % whenever lt < T-1, exactly as before.
    t=T-1;

    while (t>lt)

        for p=1:L_max
            for H=1:t
                for wf=1:n_eta_f
                    for wm=1:n_eta_m
                        for th=1:n_th

                            [stayf0(p,H,wf,wm,th,t),staym0(p,H,wf,wm,th,t),stayf1(p,H,wf,wm,th,t),staym1(p,H,wf,wm,th,t)]=stayM(H,lw_grid_f(wf),lw_grid_m(wm),t,det_earn_f(g,:), det_earn_m(g,:),home_prod(g,:),psi(g),th1(g),thtau_grid(th),lambda_grid(p), dt);

                            % Expected continuation over next-period (wf', wm', th'):
                            % option 1 keeps H, option 2 moves to H+1.
                            Awt(:,:,:)=P_wg_th(wf,wm,th,:,:,:);
                            B_Mf0(:,:,:)=val_f(p,H,:,:,:,t+1);
                            E_Mf0(p,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mf0)));
                            B_Mm0(:,:,:)=val_m(p,H,:,:,:,t+1);
                            E_Mm0(p,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mm0)));

                            B_Mf1(:,:,:)=val_f(p,H+1,:,:,:,t+1);
                            E_Mf1(p,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mf1)));
                            B_Mm1(:,:,:)=val_m(p,H+1,:,:,:,t+1);
                            E_Mm1(p,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mm1)));

                            val_matf(p,H,wf,wm,th,t,1)=stayf0(p,H,wf,wm,th,t) + disc*E_Mf0(p,H,wf,wm,th,t);
                            val_matf(p,H,wf,wm,th,t,2)=stayf1(p,H,wf,wm,th,t) + disc*E_Mf1(p,H,wf,wm,th,t);
                            val_matm(p,H,wf,wm,th,t,1)=staym0(p,H,wf,wm,th,t) + disc*E_Mm0(p,H,wf,wm,th,t);
                            val_matm(p,H,wf,wm,th,t,2)=staym1(p,H,wf,wm,th,t) + disc*E_Mm1(p,H,wf,wm,th,t);

                        end
                    end
                end
            end
        end

        for p=1:L_max
            [~,dec_M(p,:,:,:,:,t)]=max(lambda_grid(p)*val_matf(p,:,:,:,:,t,:)+(1-lambda_grid(p))*val_matm(p,:,:,:,:,t,:),[],7);
        end

        [valM_f(:,1:t,:,:,:,t), valM_m(:,1:t,:,:,:,t)] = select_option( ...
            val_matf(:,1:t,:,:,:,t,1), val_matf(:,1:t,:,:,:,t,2), ...
            val_matm(:,1:t,:,:,:,t,1), val_matm(:,1:t,:,:,:,t,2), dec_M(:,1:t,:,:,:,t));

        [val_f(:,:,:,:,:,t), val_m(:,:,:,:,:,t), dec_ud(:,:,:,:,:,t), pareto_weight_index(:,:,:,:,:,t)] = ...
            impose_participation(valM_f(:,:,:,:,:,t), valM_m(:,:,:,:,:,t), ...
            val_Df(:,:,:,t), val_Dm(:,:,:,t), dec_M(:,:,:,:,:,t), ...
            val_f(:,:,:,:,:,t), val_m(:,:,:,:,:,t), dec_ud(:,:,:,:,:,t), ...
            pareto_weight_index(:,:,:,:,:,t), t);

        t=t-1;
    end

    % Initial period (t == lt here when lt < T-1). Every grid row p uses the
    % group's own weight lambda(g) and continuation values from row p0.
    % No participation constraints are imposed, so all rows end up equal.
    for H=1:t
        for wf=1:n_eta_f
            for wm=1:n_eta_m
                for th=1:n_th

                    [stayf0(:,H,wf,wm,th,t),staym0(:,H,wf,wm,th,t),stayf1(:,H,wf,wm,th,t),staym1(:,H,wf,wm,th,t)]=stayM(H,lw_grid_f(wf),lw_grid_m(wm),t,det_earn_f(g,:), det_earn_m(g,:),home_prod(g,:),psi(g),th1(g),thtau_grid(th),lambda(g), dt);
                    Awt(:,:,:)=P_wg_th(wf,wm,th,:,:,:);
                    B_Mf0(:,:,:)=val_f(p0,H,:,:,:,t+1);
                    E_Mf0(:,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mf0)));
                    B_Mm0(:,:,:)=val_m(p0,H,:,:,:,t+1);
                    E_Mm0(:,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mm0)));

                    B_Mf1(:,:,:)=val_f(p0,H+1,:,:,:,t+1);
                    E_Mf1(:,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mf1)));
                    B_Mm1(:,:,:)=val_m(p0,H+1,:,:,:,t+1);
                    E_Mm1(:,H,wf,wm,th,t)=sum(sum(sum(Awt.*B_Mm1)));

                    val_matf(:,H,wf,wm,th,t,1)=stayf0(:,H,wf,wm,th,t) + disc*E_Mf0(:,H,wf,wm,th,t);
                    val_matf(:,H,wf,wm,th,t,2)=stayf1(:,H,wf,wm,th,t) + disc*E_Mf1(:,H,wf,wm,th,t);
                    val_matm(:,H,wf,wm,th,t,1)=staym0(:,H,wf,wm,th,t) + disc*E_Mm0(:,H,wf,wm,th,t);
                    val_matm(:,H,wf,wm,th,t,2)=staym1(:,H,wf,wm,th,t) + disc*E_Mm1(:,H,wf,wm,th,t);

                end
            end
        end
    end

    [~,dec_M(:,:,:,:,:,t)]=max(lambda(g)*val_matf(:,:,:,:,:,t,:)+(1-lambda(g))*val_matm(:,:,:,:,:,t,:),[],7);

    for H=1:t
        for wf=1:n_eta_f
            for wm=1:n_eta_m
                for th=1:n_th
                    % All rows p are identical here, so row 1's choice applies to all.
                    d= dec_M(1,H,wf,wm,th,t);
                    val_f(:,H,wf,wm,th,t)= val_matf(:,H,wf,wm,th,t,d);
                    val_m(:,H,wf,wm,th,t)= val_matm(:,H,wf,wm,th,t,d);
                end
            end
        end
    end

    dec_ud(:,:,:,:,:,t)=dec_M(:,:,:,:,:,t);
    dec(g,:,:,:,:,:,:)=dec_ud(:,:,:,:,:,:);
    pweights(g,:,:,:,:,:,:)=pareto_weight_index(:,:,:,:,:,:);

    U_f(g)=mean(mean(mean(val_f(1,1,:,:,:,lt))));
    U_m(g)=mean(mean(mean(val_m(1,1,:,:,:,lt))));

    %--- Singles
    % NOTE(review): 0.61 is a hard-coded scaling applied to single earnings
    % inside the log; its meaning is not documented in this file.

    Sglmale=zeros(n_eta_m,T);
    Em_s=zeros(n_eta_m,T);
    valm_s=zeros(n_eta_m,T);
    Bm_s=zeros(n_eta_m,1);

    Sglfemale=zeros(n_eta_f,T);
    Ef_s=zeros(n_eta_f,T);
    valf_s=zeros(n_eta_f,T);
    Bf_s=zeros(n_eta_f,1);

    t=T;

    for wm=1:n_eta_m
        valm_s(wm,t)=log(0.61*w(det_earn_m(g,:),t,0,lw_grid_m(wm),dt));
    end

    for wf=1:n_eta_f
        valf_s(wf,t)=log(0.61*w(det_earn_f(g,:),t,0,lw_grid_f(wf),dt));
    end

    for t=T-1:-1:lt
        for wm=1:n_eta_m
            Sglmale(wm,t)=log(0.61*w(det_earn_m(g,:),t,0,lw_grid_m(wm),dt));
            Bm_s(:)=valm_s(:,t+1);
            Em_s(wm,t)=P_eta_m(wm,:)*Bm_s;
            valm_s(wm,t)=Sglmale(wm,t)+disc*Em_s(wm,t);
        end

        for wf=1:n_eta_f
            Sglfemale(wf,t)=log(0.61*w(det_earn_f(g,:),t,0,lw_grid_f(wf),dt));
            Bf_s(:)=valf_s(:,t+1);
            Ef_s(wf,t)=P_eta_f(wf,:)*Bf_s;
            valf_s(wf,t)=Sglfemale(wf,t)+disc*Ef_s(wf,t);
        end
    end

    U_f0(g)=mean(valf_s(:,lt)) + sigma_f(g);
    U_m0(g)=mean(valm_s(:,lt)) + sigma_m(g);

end

%--------------------------- Group aggregates -----------------------------%
% Groups form a k x k grid, g = (i-1)*k + j (k = 3 for n_g = 9):
%   - groups sharing i are pooled for U_ftot, with the single value taken
%     from the first group of the row;
%   - groups sharing j are pooled for U_mtot, with the single value taken
%     from group j.
% Presumably i indexes the woman's type and j the man's type -- confirm.
for i=1:k_types
    gs = (i-1)*k_types + (1:k_types);
    U_ftot(gs) = sum(exp(U_f(gs))) + exp(U_f0(gs(1)));
end

for j=1:k_types
    gs = j + (0:k_types-1)*k_types;
    U_mtot(gs) = sum(exp(U_m(gs))) + exp(U_m0(gs(1)));
end

end


function [vMf, vMm] = select_option(vf1, vf2, vm1, vm2, dM)
%SELECT_OPTION Return each spouse's value under the marriage option in dM.
%   vf1/vf2 and vm1/vm2 hold the female and male values of options 1 and 2.
%   dM holds the chosen option (1 or 2, the index returned by max).
%   All five arrays have the same size.
use2 = (dM == 2);
vMf = vf1;
vMf(use2) = vf2(use2);
vMm = vm1;
vMm(use2) = vm2(use2);
end


function [vf, vm, du, pwi] = impose_participation(vMf, vMm, vDf, vDm, dM, vf, vm, du, pwi, H_last)
%IMPOSE_PARTICIPATION Apply both spouses' participation constraints for one period.
%   Inputs for a single period:
%     vMf, vMm, dM, vf, vm, du, pwi   L_max x H_max x n_eta_f x n_eta_m x n_th
%     vDf, vDm                        H_max x n_eta_f x n_eta_m divorce values
%   Only H = 1:H_last is processed. For each state and grid weight p:
%     - both constraints hold at p: take the marriage values at p;
%     - only the man's fails: take the first q < p (searching downward,
%       i.e. toward less weight on the woman) where both hold;
%     - only the woman's fails: take the first q > p (searching upward)
%       where both hold.
%   If no such q exists, or both constraints fail, vf, vm, du and pwi are
%   returned unchanged for that state.
L_max = size(vMf,1);
n_eta_f = size(vMf,3);
n_eta_m = size(vMf,4);
n_th = size(vMf,5);

for p=1:L_max
    for H=1:H_last
        for wf=1:n_eta_f
            for wm=1:n_eta_m
                for th=1:n_th
                    Df = vDf(H,wf,wm);
                    Dm = vDm(H,wf,wm);
                    f_ok = vMf(p,H,wf,wm,th) >= Df;
                    m_ok = vMm(p,H,wf,wm,th) >= Dm;

                    src = 0;
                    % The three conditions below are mutually exclusive. Explicit
                    % '<' comparisons keep NaN values out of every branch.
                    if f_ok && m_ok
                        src = p;
                    elseif f_ok && (vMm(p,H,wf,wm,th) < Dm)
                        for q=p-1:-1:1
                            if (vMf(q,H,wf,wm,th) >= Df) && (vMm(q,H,wf,wm,th) >= Dm)
                                src = q;
                                break
                            end
                        end
                    elseif (vMf(p,H,wf,wm,th) < Df) && m_ok
                        for q=p+1:L_max
                            if (vMf(q,H,wf,wm,th) >= Df) && (vMm(q,H,wf,wm,th) >= Dm)
                                src = q;
                                break
                            end
                        end
                    end

                    if src > 0
                        vf(p,H,wf,wm,th) = vMf(src,H,wf,wm,th);
                        vm(p,H,wf,wm,th) = vMm(src,H,wf,wm,th);
                        du(p,H,wf,wm,th) = dM(src,H,wf,wm,th);
                        pwi(p,H,wf,wm,th) = src;
                    end
                end
            end
        end
    end
end
end

